package com.hanxiaozhang.sqlexecute;

import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLLimit;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.SQLIntegerExpr;
import com.alibaba.druid.sql.ast.statement.*;
import com.alibaba.druid.sql.parser.ParserException;
import com.alibaba.druid.util.JdbcUtils;
import com.alibaba.fastjson.JSON;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.jdbc.DataSourceProperties;
import org.springframework.stereotype.Service;

import javax.annotation.Resource;
import java.util.List;
import java.util.Map;

/**
 * 〈一句话功能简述〉<br>
 * 〈执行Sql服务，基于Mybatis+使用Druid〉
 *
 * @author hanxinghua
 * @create 2023/6/14
 * @since 1.0.0
 */
@Slf4j
@Service
public class SqlExecuteServiceImpl implements SqlExecuteService {


    private static final int DEFAULT_LIMIT = 1000;

    @Resource
    private SqlExecuteDao sqlExecuteDao;

    @Autowired
    private DataSourceProperties dataSourceProperties;


    /**
     * 执行
     * <p>
     * Tips：
     * 1. 可以增加权限校验。
     * 2. 支持数据备份。
     *
     * @param sql
     * @return
     */
    @Override
    public String execute(String sql) {

        String result = null;
        // 校验SQL语法
        SQLStatement statement = checkSyntax(sql);
        try {
            if (statement instanceof SQLInsertStatement) {
                result = String.valueOf(sqlExecuteDao.insert(sql));
            }
            if (statement instanceof SQLDeleteStatement) {
                result = String.valueOf(sqlExecuteDao.delete(sql));
            }
            if (statement instanceof SQLSelectStatement) {
                setQueryLimit((SQLSelectStatement) statement, DEFAULT_LIMIT);
                List<Map<String, Object>> list = sqlExecuteDao.select(SQLUtils.toMySqlString(statement));
                result = JSON.toJSONString(list);
            }
            if (statement instanceof SQLUpdateStatement) {
                result = String.valueOf(sqlExecuteDao.update(sql));
            }
        } catch (Exception e) {
            log.error("执行Sql异常，异常信息：[{}]", e);
            throw new RuntimeException("执行Sql异常!");
        }
        return result;
    }


    /**
     * 设置查询limit
     *
     * @param statement
     * @param limitNum
     */
    private void setQueryLimit(SQLSelectStatement statement, Integer limitNum) {
        SQLSelectQuery query = statement.getSelect().getQuery();
        if (query instanceof SQLSelectQueryBlock) {
            // 单表SQL查询（包括管理查询）
            SQLSelectQueryBlock select = (SQLSelectQueryBlock) query;
            if (select.getLimit() == null) {
                select.setLimit(new SQLLimit(new SQLIntegerExpr(limitNum)));
            }
        } else if (query instanceof SQLUnionQuery) {
            // 联合表SQL查询
            SQLUnionQuery select = (SQLUnionQuery) query;
            if (select.getLimit() == null) {
                select.setLimit(new SQLLimit(new SQLIntegerExpr(limitNum)));
            }
        }
    }


    /**
     * 校验SQL语法
     *
     * @param sql
     * @return
     */
    private SQLStatement checkSyntax(String sql) {
        List<SQLStatement> sqlStatements = SQLUtils.parseStatements(sql, dataSourceProperties.getDriverClassName());
        if (sqlStatements.size() > 1) {
            throw new ParserException("每次只允许执行一个sql语句！");
        }
        SQLStatement statement = sqlStatements.get(0);
        return statement;
    }


}
